import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import json
import bluepysnap
import matplotlib.patches as patches
import numpy as np


def plot_circuit_compo_pie_chart(sizes, ax):
    """Draw one circuit-composition pie chart on ``ax`` and return ``ax``.

    ``sizes`` holds the three wedge sizes, in the order Ecel, Spp, Runaway;
    each wedge is annotated with its percentage (one decimal place) and
    coloured with the matplotlib default cycle colours C0, C1, C2.
    """
    pie_options = dict(
        labels=("Ecel", "Spp", "Runaway"),
        autopct="%1.1f%%",
        colors=[f"C{k}" for k in range(3)],
    )
    ax.pie(sizes, **pie_options)
    return ax


def plot_Ecel_circuit_compo_chart(output_path="variedEcel_circuit_composition.pdf"):
    """Plot one pie chart per circuit composition, stacked vertically, and save it.

    Each row of ``circuit_compo_ratios`` gives the percentages of Ecel, Spp and
    Runaway cells, in the order used as wedge labels by
    ``plot_circuit_compo_pie_chart``.

    Args:
        output_path: file the figure is written to (defaults to the original
            ``variedEcel_circuit_composition.pdf``).
    """
    circuit_compo_ratios = [
        [0, 50, 50],
        [25, 37.5, 37.5],
        [50, 25, 25],
        [75, 12.5, 12.5],
        [100, 0, 0],
    ]  # Ecel / Spp / Runaway percentages
    num_rows = len(circuit_compo_ratios)
    plot_size = (15, 2)

    fig = plt.figure(figsize=(plot_size[0], plot_size[1] * num_rows))
    for index, ratios in enumerate(circuit_compo_ratios):
        ax = fig.add_subplot(num_rows, 1, index + 1)
        plot_circuit_compo_pie_chart(ratios, ax)

    fig.tight_layout()
    fig.savefig(output_path)
    # The figure is only written to disk; close it so it does not stay
    # registered in pyplot for the rest of the session.
    plt.close(fig)


def get_spindle_amp(
    sim,
    window_starts=None,
    window_length=1000,
    time_binsize=5,
    peak_height=2,
    peak_width=3,
):
    """Estimate per-window spindle amplitudes from the Rt_RC population rate.

    The ``thalamus_neurons`` spike report is restricted to mtype ``Rt_RC`` and
    binned into a PSTH (spikes per spiking neuron per second). Inside each
    analysis window, peaks of that rate are detected with
    ``scipy.signal.find_peaks``.

    Args:
        sim: bluepysnap simulation to read spikes from.
        window_starts: start times (ms) of the analysis windows. Defaults to the
            16 windows starting at 3500, 4500, ..., 18500 ms.
        window_length: window duration (ms). A bin belongs to a window when its
            start time lies in ``[start, start + window_length]``, with both ends
            inclusive.
        time_binsize: PSTH bin width (ms).
        peak_height: ``height`` threshold (Hz) passed to ``find_peaks``.
        peak_width: ``width`` threshold (in bins) passed to ``find_peaks``.

    Returns:
        ``(fs, fs1)``: two lists with one entry per window.
        ``fs`` is the mean height (Hz) of all detected peaks in the window.
        ``fs1`` is the height of the first (earliest) detected peak.
        Windows with no detected peak contribute 0 to both lists.

    Raises:
        ValueError: if the filtered report contains no spikes.
    """
    from scipy.signal import find_peaks

    report = (
        sim.spikes["thalamus_neurons"].spike_report.filter(group={"mtype": "Rt_RC"}).report
    )
    times = report.index.to_numpy()
    if times.size == 0:
        raise ValueError("get_spindle_amp: the Rt_RC spike report is empty")

    if window_starts is None:
        window_starts = range(3500, 19500, 1000)

    time_start, time_stop = np.min(times), np.max(times)
    bins = np.append(np.arange(time_start, time_stop, time_binsize), time_stop)
    hist, bin_edges = np.histogram(times, bins=bins)
    # Normalise by the number of distinct (id, population) pairs that spiked.
    # Bins are in ms, so multiply the width by 0.001 to get a rate in Hz.
    node_count = report[["ids", "population"]].drop_duplicates().shape[0]
    freq = hist / node_count / (0.001 * time_binsize)
    bin_starts = bin_edges[:-1]

    fs, fs1 = [], []
    for low in window_starts:
        high = low + window_length
        f = freq[(bin_starts >= low) & (bin_starts <= high)]
        peaks = find_peaks(f, height=peak_height, width=peak_width)[0]
        if peaks.size == 0:
            fs.append(0)
            fs1.append(0)
        else:
            fs.append(np.mean(f[peaks]))
            fs1.append(f[peaks[0]])
    return fs, fs1


def make_plot_PSTH_up_down_RT(
    sim,
    ts,
    te,
    UP_state_start_times,
    DOWN_state_start_times,
    ax,
    *,
    time_binsize=None,
    up_state_duration=500,
    ylim=120,
    xlim=(3000, 19000),
):
    """Plot the PSTH of mc2;Rt Rt_RC neurons and summarise its UP-state peaks.

    Spikes of ``thalamus_neurons`` in region ``mc2;Rt`` with mtype ``Rt_RC``,
    restricted to ``[ts, te]``, are binned into a rate (spikes per spiking
    neuron per second). The rate is drawn on ``ax``, with a grey band over each
    UP state.

    Args:
        sim: bluepysnap simulation.
        ts, te: start / stop time (ms) of the spikes to load.
        UP_state_start_times: UP-state onset times (ms).
        DOWN_state_start_times: end of each UP window (ms), paired with
            ``UP_state_start_times`` via ``zip``.
        ax: matplotlib axes to draw on.
        time_binsize: PSTH bin width (ms). If ``None``, a heuristic is used that
            aims for ~100 spikes per bin, capped at 10 ms.
        up_state_duration: width (ms) of the grey band drawn at each UP onset.
        ylim: upper y-limit of the plot (Hz).
        xlim: x-limits of the plot (ms).

    Returns:
        ``(mean, std)`` of the maximum PSTH rate inside each
        ``[UP start, DOWN start]`` window. Only bin start times are tested
        against the window, with both ends inclusive. Windows containing no
        bins are skipped. ``(nan, nan)`` is returned if no window contains
        any bins.

    Raises:
        ValueError: if no matching spikes are found in ``[ts, te]``.
    """
    RT_gids = sim.circuit.nodes["thalamus_neurons"].ids({"region": "mc2;Rt", "mtype": "Rt_RC"})
    spikes_df = sim.spikes.filter(group=RT_gids, t_start=ts, t_stop=te).report
    spike_times = spikes_df.index.to_numpy()
    if spike_times.size == 0:
        raise ValueError(
            f"make_plot_PSTH_up_down_RT: no Rt_RC spikes in mc2;Rt between {ts} and {te} ms"
        )

    ax.set_ylabel("PSTH [Hz]")
    ax.set_ylim(0, ylim)
    ax.set_xlim(*xlim)

    time_start = np.min(spike_times)
    time_stop = np.max(spike_times)
    if time_binsize is None:
        # Heuristic for a nice bin size: ~100 spikes per bin on average, at most 10 ms.
        time_binsize = min(10.0, (time_stop - time_start) / ((len(spike_times) / 100.0) + 1.0))

    bins = np.append(np.arange(time_start, time_stop, time_binsize), time_stop)
    hist, bin_edges = np.histogram(spike_times, bins=bins)

    # Count distinct (id, population) pairs that spiked, as get_spindle_amp does.
    # Node ids are only unique within a population.
    node_count = spikes_df[["ids", "population"]].drop_duplicates().shape[0]
    # Bins are in ms; multiply the width by 0.001 to get a rate in Hz.
    freq = hist / node_count / (0.001 * time_binsize)
    bin_starts = bin_edges[:-1]

    # Use the middle of each bin as its x-coordinate.
    ax.plot(0.5 * (bin_edges[1:] + bin_starts), freq, drawstyle="steps-mid")
    ax.spines.top.set_visible(False)
    ax.spines.right.set_visible(False)

    # Shade each UP state over the full height of the axes.
    for t in UP_state_start_times:
        ax.axvspan(
            t,
            t + up_state_duration,
            linewidth=1,
            edgecolor="none",
            facecolor="grey",
            alpha=0.2,
        )

    peak_rates = []
    for low, high in zip(UP_state_start_times, DOWN_state_start_times):
        in_window = freq[(bin_starts >= low) & (bin_starts <= high)]
        # A window outside the recorded spike range has no bins; np.max would raise on it.
        if in_window.size:
            peak_rates.append(np.max(in_window))
    if not peak_rates:
        return np.nan, np.nan
    return np.mean(peak_rates), np.std(peak_rates)


if __name__ == "__main__":

    plot_Ecel_circuit_compo_chart()

    sim_root = "../../simulations/08_08_Cav_variants/"

    # Time range (ms) of the spikes loaded for the PSTH.
    ts, te = 1500, 20000

    # (variant folder prefix, Ecel fraction of that circuit). The fraction is the
    # x-position in the amplitude box plot. Keeping it next to its variant avoids
    # a separate list that has to stay in the same order.
    sim_types = [
        ("variant10_50pSpp50pRunaway", 0.0),
        ("variant9_mixed25pEcel75pSppRunaway", 0.25),
        ("variant7_mixed50pEcel50pSppRunaway", 0.5),
        ("variant8_mixed75pEcel25pSppRunaway", 0.75),
        ("variant1_Ecel", 1.0),
    ]
    # Simulation-folder suffixes analysed for every variant.
    sims_list = ["_90percent_CT_up_down"]

    # The 16 UP-state onsets (ms): 3500, 4500, ..., 18500.
    up_start = list(range(3500, 19500, 1000))
    # Matching window ends (ms): 4000, 5000, ..., 19000.
    # These are passed as the DOWN-state start times.
    up_end = list(range(4000, 20000, 1000))

    # One PSTH row per circuit variant.
    # squeeze=False keeps `axes` 2-D even when there is only one variant.
    plot_size = (15, 2)
    num_rows = len(sim_types)
    fig, axes = plt.subplots(
        num_rows, 1, figsize=(plot_size[0], plot_size[1] * num_rows), squeeze=False
    )

    # Columns: Ecel fraction. Rows: per-window spindle amplitude.
    amp_df = pd.DataFrame()
    for ax, (sim_type_prefix, ecel_fraction) in zip(axes[:, 0], sim_types):
        for sim_folder in sims_list:
            sim_folder_full_name = sim_type_prefix + "_SUPER_LONG" + sim_folder
            sim_path = (
                sim_root + sim_type_prefix + "/" + sim_folder_full_name + "/simulation_config.json"
            )
            # Context manager so the config file handle is closed.
            with open(sim_path) as f:
                data = json.load(f)

            circuit_var = data["network"].split("/")[-1]
            percent_CT_input = data["inputs"]["spikeReplay_ct_UP1"]["spike_file"].split("_")[-4]

            print(f"\nAnalysing sim for circuit {circuit_var}, percent_CT_input {percent_CT_input}")

            _sim = bluepysnap.Simulation(sim_path)

            mean_freq, mean_freq_err = make_plot_PSTH_up_down_RT(
                _sim, ts, te, up_start, up_end, ax
            )
            _amps, _ = get_spindle_amp(_sim)
            amp_df[ecel_fraction] = _amps
            print(amp_df)
            print(f"Mean freq is {mean_freq} +- {mean_freq_err}")

    axes[-1, 0].set_xlabel("Time [ms]")
    fig.tight_layout()
    fig.savefig(f"variedEcel_composition_{sims_list[-1]}.pdf")

    plt.figure(figsize=(5, 3))
    sns.boxplot(data=amp_df, native_scale=True, color="k", fill=False, showfliers=False)
    sns.stripplot(data=amp_df, native_scale=True, color="k", size=3)
    plt.xlabel("Ecel fraction")
    # get_spindle_amp returns PSTH peak rates (spikes / neuron / s), so the unit is Hz.
    plt.ylabel("Spindle amplitude [Hz]")
    plt.tight_layout()
    plt.savefig("variedEcel_composition_amp.pdf")
